#include <gtest/gtest.h>
#include "c/ddk/graph/handle_types.h"
#include "c/ddk/graph/operator.h"
#include "c/ddk/graph/attr_value.h"
#include "graph/csrc/resource_manager.h"
#include "c/ddk/graph/context.h"
#include "graph/types.h"
#include "graph/operator.h"
#include "graph/shape.h"

#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"

using namespace ge;
using namespace std;
using namespace hiai;

class UTEST_c_operator : public testing::Test {
public:
    void SetUp()
    {
        resMgr_ = HIAI_IR_ResourceManagerCreate();
    }

    void TearDown()
    {
        HIAI_IR_ResourceManagerDestroy(&resMgr_);
    }
public:
    OpHandle CreateData()
    {
        const char* name = "x";
        uint32_t shapeNum = 4;
        OpHandle dataOp = HIAI_IR_CreatePlaceHodler(
            resMgr_, name, HIAI_Format::HIAI_FORMAT_NCHW, HIAI_DATATYPE_FLOAT32, inputShape_, shapeNum);
        return dataOp;
    }

    ResMgrHandle resMgr_ = nullptr;
    int64_t inputShape_[4] = {10, 4, 120, 440};
};

// Verifies that HIAI_IR_CreatePlaceHodler yields a "Data" operator whose name, input
// tensor shape, format and data type match the arguments used by CreateData().
TEST_F(UTEST_c_operator, CreatePlaceHodler)
{
    OpHandle dataOp = CreateData();
    // Fatal assertion: everything below dereferences the created op.
    ASSERT_NE(dataOp, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr_);
    BasePtr inputOpBasePtr = resMgrPtr->GetSrcPtr(dataOp);

    // dynamic_pointer_cast returns nullptr if the stored object is not an Operator.
    std::shared_ptr<Operator> dataOpPtr = std::dynamic_pointer_cast<Operator>(inputOpBasePtr);
    ASSERT_NE(dataOpPtr, nullptr);
    EXPECT_EQ("x", dataOpPtr->GetName());
    EXPECT_EQ("Data", dataOpPtr->GetType());

    TensorDesc tensorDesc = dataOpPtr->GetInputDesc(0);
    ge::Shape shape = tensorDesc.GetShape();
    // Take the rank from the fixture array instead of repeating the literal 4.
    const size_t rank = sizeof(inputShape_) / sizeof(inputShape_[0]);
    for (size_t i = 0; i < rank; ++i) {
        EXPECT_EQ(inputShape_[i], shape.GetDim(i));
    }
    EXPECT_EQ(HIAI_Format::HIAI_FORMAT_NCHW, tensorDesc.GetFormat());
    EXPECT_EQ(ge::DataType::DT_FLOAT, tensorDesc.GetDataType());
}

// Verifies that HIAI_IR_CreateOp builds an "Activation" op named "activation", fed by the
// placeholder, and that the int64 "mode" attribute passed through HIAI_IR_Params is stored
// on the resulting OpDesc.
TEST_F(UTEST_c_operator, CreateOp)
{
    OpHandle dataOpPtr = CreateData();
    ASSERT_NE(dataOpPtr, nullptr);

    // The HIAI_IR_* structs are value-initialized so that any field this test does not set
    // (e.g. opConfig.outputNums) is zero/null rather than indeterminate when read by the API.
    HIAI_IR_Input input[1] = {};
    input[0].op = dataOpPtr;
    input[0].outIndex = 0;

    AttrHandle modeAttr = HIAI_IR_CreateInt64Attr(resMgr_, 2); // mode
    ASSERT_NE(modeAttr, nullptr);

    HIAI_IR_Params params[1] = {};
    params[0].name = "mode";
    params[0].attrValue = modeAttr;

    HIAI_IR_OpConfig opConfig = {};
    opConfig.opType = "Activation";
    opConfig.opName = "activation";
    opConfig.input = input;
    opConfig.inputNums = 1;
    opConfig.params = params;
    opConfig.paramsNums = 1;

    OpHandle actOp = HIAI_IR_CreateOp(resMgr_, &opConfig);
    ASSERT_NE(actOp, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr_);
    BasePtr inputOpBasePtr = resMgrPtr->GetSrcPtr(actOp);

    std::shared_ptr<Operator> actOpPtr = std::dynamic_pointer_cast<Operator>(inputOpBasePtr);
    ASSERT_NE(actOpPtr, nullptr);
    EXPECT_EQ("activation", actOpPtr->GetName());
    EXPECT_EQ("Activation", actOpPtr->GetType());

    OpDescPtr opDescPtr = OpDescUtils::GetOpDescFromOperator(*actOpPtr);
    ASSERT_NE(opDescPtr, nullptr);
    int64_t value = 0;
    // A missing attribute must fail the test rather than silently compare the default 0.
    ASSERT_TRUE(ge::AttrUtils::GetInt(opDescPtr, "mode", value));
    EXPECT_EQ(value, 2);
}

// Verifies building a "SplitD" op with multiple outputs (outputNums) from the placeholder,
// checks its split_dim/num_split attributes, and then feeds outputs 0 and 1 into an "Add" op.
TEST_F(UTEST_c_operator, createOpWithDynamicOutputOP)
{
    // data
    OpHandle dataOpPtr = CreateData();
    ASSERT_NE(dataOpPtr, nullptr);

    // split
    // HIAI_IR_* structs are value-initialized so any field left unset below (pointers,
    // counts) is zero/null rather than indeterminate when HIAI_IR_CreateOp reads it.
    HIAI_IR_Input input[1] = {};
    input[0].op = dataOpPtr;
    input[0].outIndex = 0;

    AttrHandle splitDimAttr = HIAI_IR_CreateInt64Attr(resMgr_, 2); // split_dim
    ASSERT_NE(splitDimAttr, nullptr);
    HIAI_IR_Params params[2] = {};
    params[0].name = "split_dim";
    params[0].attrValue = splitDimAttr;

    AttrHandle numSplitAttr = HIAI_IR_CreateInt64Attr(resMgr_, 3); // num_split
    ASSERT_NE(numSplitAttr, nullptr);
    params[1].name = "num_split";
    params[1].attrValue = numSplitAttr;

    HIAI_IR_OpConfig opConfig = {};
    opConfig.opType = "SplitD";
    opConfig.opName = "split";
    opConfig.input = input;
    opConfig.inputNums = 1;
    opConfig.params = params;
    opConfig.paramsNums = 2;
    // NOTE(review): num_split is 3 but only 2 outputs are requested here (and only outputs
    // 0 and 1 are consumed below) — confirm whether outputNums is meant to match num_split.
    opConfig.outputNums = 2;

    OpHandle splitOp = HIAI_IR_CreateOp(resMgr_, &opConfig);
    // Fatal: splitOp is dereferenced below.
    ASSERT_NE(splitOp, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr_);
    BasePtr splitOpBasePtr = resMgrPtr->GetSrcPtr(splitOp);
    std::shared_ptr<Operator> splitOpPtr = std::dynamic_pointer_cast<Operator>(splitOpBasePtr);
    ASSERT_NE(splitOpPtr, nullptr);

    OpDescPtr splitOpDescPtr = OpDescUtils::GetOpDescFromOperator(*splitOpPtr);
    ASSERT_NE(splitOpDescPtr, nullptr);
    int64_t splitDim = 0;
    int64_t numSplit = 0;
    // Missing attributes must fail the test rather than compare the initial 0s.
    ASSERT_TRUE(ge::AttrUtils::GetInt(splitOpDescPtr, "split_dim", splitDim));
    ASSERT_TRUE(ge::AttrUtils::GetInt(splitOpDescPtr, "num_split", numSplit));
    EXPECT_EQ(splitDim, 2);
    EXPECT_EQ(numSplit, 3);

    // add: consumes split outputs 0 and 1
    HIAI_IR_Input addInputOpPair[2] = {};
    addInputOpPair[0].op = splitOp;
    addInputOpPair[0].outIndex = 0;

    addInputOpPair[1].op = splitOp;
    addInputOpPair[1].outIndex = 1;

    // Value-initialized: params (pointer) and outputNums are not set for this op.
    HIAI_IR_OpConfig addOpConfig = {};
    addOpConfig.opType = "Add";
    addOpConfig.opName = "add";
    addOpConfig.input = addInputOpPair;
    addOpConfig.inputNums = 2;
    addOpConfig.paramsNums = 0;
    OpHandle addOp = HIAI_IR_CreateOp(resMgr_, &addOpConfig);
    ASSERT_NE(addOp, nullptr);

    BasePtr addOpBasePtr = resMgrPtr->GetSrcPtr(addOp);
    std::shared_ptr<Operator> addOpPtr = std::dynamic_pointer_cast<Operator>(addOpBasePtr);
    ASSERT_NE(addOpPtr, nullptr);
    EXPECT_EQ("add", addOpPtr->GetName());
    EXPECT_EQ("Add", addOpPtr->GetType());
}